{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow教程-keras构建RNN\n",
    "循环神经网络是一种在时间维度上进行迭代的神经网络，在建模序列数据方面有着优越的性能。\n",
    "TensorFlow2中包含了主流的rnn网络实现，以Keras RNN API的方式提供调用。其特性如下：\n",
    "- 易用性： 内置tf.keras.layers.RNN，tf.keras.layers.LSTM，tf.keras.layers.GRU， 可以快速构建RNN模块。\n",
    "- 易于定制：可以使用自定义操作循环来构建RNN模块，并通过tf.keras.layers.RNN调用。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "import collections\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 构建一个简单模型\n",
    "keras中内置了三个RNN层\n",
    "- tf.keras.layers.SimpleRNN，普通RNN网络。\n",
    "- tf.keras.layers.GRU，门控循环神经网络。\n",
    "- tf.keras.layers.LSTM，长短期记忆神经网络。\n",
    "下面是一个循环神经网络的例子，它使用LSTM层来处理输入的词嵌入，LSTM迭代的次数与词个个数一致\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (None, None, 64)          64000     \n",
      "_________________________________________________________________\n",
      "lstm (LSTM)                  (None, 128)               98816     \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 164,106\n",
      "Trainable params: 164,106\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential()\n",
    "# input_dim是词典大小， output_dim是词嵌入维度\n",
    "model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
    "# 添加lstm层，其会输出最后一个时间步的输出\n",
    "model.add(layers.LSTM(128))\n",
    "model.add(layers.Dense(10, activation='softmax'))\n",
    "\n",
    "model.summary()"
   ]
  },
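  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the parameter counts in the summary above can be reproduced by hand. This is a worked sketch that assumes the standard Keras parameterization, in which an LSTM has four gates and each gate has an input kernel, a recurrent kernel, and a bias. The symbols $V$ (vocabulary size), $d$ (embedding dimension), $u$ (LSTM units), and $c$ (number of output classes) are introduced here for illustration only.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{Embedding} &= V \\cdot d = 1000 \\cdot 64 = 64{,}000 \\\\\n",
    "\\text{LSTM} &= 4\\,\\bigl(u\\,(d + u) + u\\bigr) = 4\\,(128 \\cdot 192 + 128) = 98{,}816 \\\\\n",
    "\\text{Dense} &= u \\cdot c + c = 128 \\cdot 10 + 10 = 1{,}290\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The three terms add up to 164,106, which matches the total reported by model.summary()."
   ]
  },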
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 输出和状态\n",
    "默认情况下RNN层输出最后一个时间步的输出，输出的形状为(batch_size, units), 其中unit是传给层的构造参数。\n",
    "如果要返回rnn每个时间步的序列，需要置return_sequence=True, 此时输出的形状为(batch_size, time_steps, units)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, None, 64)          64000     \n",
      "_________________________________________________________________\n",
      "gru (GRU)                    (None, None, 128)         74496     \n",
      "_________________________________________________________________\n",
      "simple_rnn (SimpleRNN)       (None, 128)               32896     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 172,682\n",
      "Trainable params: 172,682\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
    "# 返回整个rnn序列的输出\n",
    "model.add(layers.GRU(128, return_sequences=True))\n",
    "model.add(layers.SimpleRNN(128))\n",
    "model.add(layers.Dense(10, activation='softmax'))\n",
    "model.summary()\n"
   ]
  },
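  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in this summary can also be checked by hand. The sketch below assumes the TensorFlow 2 defaults, in particular reset_after=True for GRU, which gives each of its three gates two bias vectors. Here $d$ is the input feature dimension of a layer and $u$ its number of units (symbols introduced for illustration only). The SimpleRNN layer has $d = 128$ because it receives the 128-dimensional GRU output.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{GRU} &= 3\\,\\bigl(u\\,(d + u) + 2u\\bigr) = 3\\,(128 \\cdot 192 + 256) = 74{,}496 \\\\\n",
    "\\text{SimpleRNN} &= u\\,(d + u) + u = 128 \\cdot 256 + 128 = 32{,}896 \\\\\n",
    "\\text{Dense} &= 128 \\cdot 10 + 10 = 1{,}290\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Together with the 64,000 embedding parameters, this gives the reported total of 172,682."
   ]
  },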
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "RNN层可以返回最后一个时间步的状态。\n",
    "\n",
    "返回的状态可以用于恢复RNN或初始化另一个RNN。此设置通常在seq2seq模型中用到，其将编码器的最终状态作为解码器的初始状态。\n",
    "\n",
    "要使RNN层返回最终的状态需要在创建图层时置return_state=True。请注意，LSTM有两个状态向量，而GRU只有一个。\n",
    "\n",
    "要配置图层的初始状态， 需要网络构建中传入initial_state。其中状态的维度必须与图层的unit大小一样。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_2 (InputLayer)            [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding_2 (Embedding)         (None, None, 64)     64000       input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "embedding_3 (Embedding)         (None, None, 64)     64000       input_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "encoder (LSTM)                  [(None, 64), (None,  33024       embedding_2[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "decoder (LSTM)                  (None, 64)           33024       embedding_3[0][0]                \n",
      "                                                                 encoder[0][1]                    \n",
      "                                                                 encoder[0][2]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 10)           650         decoder[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 194,698\n",
      "Trainable params: 194,698\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "encoder_vocab = 1000\n",
    "decoder_vocab = 2000\n",
    "\n",
    "\n",
    "# 编码层\n",
    "encode_input = layers.Input(shape=(None, ))\n",
    "encode_emb = layers.Embedding(input_dim=encoder_vocab, output_dim=64)(encode_input)\n",
    "# 同时返回状态\n",
    "encode_out, state_h, state_c = layers.LSTM(64, return_state=True, name='encoder')(encode_emb)\n",
    "encode_state = [state_h, state_c]\n",
    "\n",
    "# 解码层\n",
    "decode_input =layers.Input(shape=(None, ))\n",
    "decode_emb = layers.Embedding(input_dim=encoder_vocab, output_dim=64)(decode_input)\n",
    "# 编码器的最终状态， 作为解码器的初始状态\n",
    "decode_out = layers.LSTM(64, name='decoder')(decode_emb, initial_state=encode_state)\n",
    "output = layers.Dense(10, activation='softmax')(decode_out)\n",
    "\n",
    "model = tf.keras.Model([encode_input, decode_input], output)\n",
    "model.summary()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 RNN层和RNN Cell\n",
    "除了内置的RNN层以外， RNN API还提供单元级API。与可以处理整批次的输入序列不同， RNN Cell仅能处理单个时间步的数据。\n",
    "如果想处理整批次的输入数据，需要将RNN Cell包含在tf.keras.layers.RNN中， 如：RNN(LSTMCell(10)\n",
    "\n",
    "同效果上说， RNN(LSTMCell(10))等价于LSTM(10)。但内置的GRU和LSTM层可以使用CuDNN，以提高计算性能。\n",
    "\n",
    "下面是三个内置的RNN单元， 以及其对应的RNN层。\n",
    "- tf.keras.layers.SimpleRNNCell对应于SimpleRNN图层。\n",
    "- tf.keras.layers.GRUCell对应于GRU图层。\n",
    "- tf.keras.layers.LSTMCell对应于LSTM图层。\n",
    "\n",
    "注： tf.keras.layers.RNN类可以使为研究实现自定义RNN体系结构变得非常容易。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_2\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_4 (Embedding)      (None, None, 64)          64000     \n",
      "_________________________________________________________________\n",
      "rnn (RNN)                    (None, 64)                33024     \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 10)                650       \n",
      "=================================================================\n",
      "Total params: 97,674\n",
      "Trainable params: 97,674\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential()\n",
    "# input_dim是词典大小， output_dim是词嵌入维度\n",
    "model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
    "# 添加lstm层，其会输出最后一个时间步的输出\n",
    "model.add(layers.RNN(layers.LSTMCell(64)))\n",
    "model.add(layers.Dense(10, activation='softmax'))\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 跨批次状态\n",
    "在处理超长序列（有可能无限长）时，可能需要使用到跨批次状态的模式。\n",
    "\n",
    "通常，每次看到新的批次时， 都会重置RNN层的内部状态（state，非权重。即该层看到的每个样本都独立于过去）。\n",
    "\n",
    "但，当序列过长，需要分不同批次输入时，则无需重置RNN的状态(上一批的结束状态，为下一批的初始状态)。这样，即使一次只看到一个子序列，网络层也可以保留整个序列的信息。\n",
    "\n",
    "我们可以置state_ful=True来设置跨批次状态。\n",
    "\n",
    "如果有序列s = [t0, t1, ... t1546, t1547]， 可以将其分为以下批次数据：\n",
    "\n",
    "\n",
    "s1 = [t0, t1, ... t100]\n",
    "\n",
    "s2 = [t101, ... t201]\n",
    "\n",
    "...\n",
    "\n",
    "16 = [t1501, ... t1547]\n",
    "\n",
    "具体调用方法如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_sequence = []\n",
    "lstm_layer = layers.LSTM(64, stateful=True)\n",
    "for s in sub_sequence:\n",
    "    output = lstm_layer(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "要清除状态时，可以使用layer.reset_states()。\n",
    "\n",
    "注意：在此设置中，i假定给定批次中的样品i是上一个批次中样品的延续。这意味着所有批次应包含相同数量的样本（批次大小）。例如，如果一个批次包含[sequence_A_from_t0_to_t100, sequence_B_from_t0_to_t100]，则下一个批次应包含[sequence_A_from_t101_to_t200, sequence_B_from_t101_to_t200]。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "跨批次状态例子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "para1 = np.random.random((20, 10, 50)).astype(np.float32)\n",
    "para2 = np.random.random((20, 10, 50)).astype(np.float32)\n",
    "para3 = np.random.random((20, 10, 50)).astype(np.float32)\n",
    "lstm_layer = layers.LSTM(64, stateful=True)\n",
    "output = lstm_layer(para1)\n",
    "output = lstm_layer(para2)\n",
    "output = lstm_layer(para3)\n",
    "lstm_layer.reset_states()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 双向LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于时间序列以外的序列，比如文本，RNN不仅可以正向处理序列，也可以反向处理序列。例如要预测句子的某个单纯，上下文对单词都有用。\n",
    "\n",
    "Keras提供了tf.keras.layers.Bidirectional的API，构建双向RNN。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_3\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "bidirectional (Bidirectional (None, 5, 128)            38400     \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 64)                41216     \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 10)                650       \n",
      "=================================================================\n",
      "Total params: 80,266\n",
      "Trainable params: 80,266\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(layers.Bidirectional(layers.LSTM(64, return_sequences=True),\n",
    "                              input_shape=(5, 10)))\n",
    "model.add(layers.Bidirectional(layers.LSTM(32)))\n",
    "model.add(layers.Dense(10, activation='softmax'))\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在内部， Bidirectional将复制传入的RNN，并将逆序输入新复制的RNN，并输出前向输出和后向输出的叠加。如果想要其他合并行为（如串联），可以修改merge_mode等参数。"
   ]
  },
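  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the Bidirectional summary above can be checked by hand. The sketch below assumes the standard Keras LSTM parameterization and that the wrapper holds two independent LSTMs, one per direction. Here $d$ is the input feature dimension and $u$ the number of units per direction (symbols introduced for illustration only). The second layer has $d = 128$ because the summary reports 128 output features for the first layer.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{bidirectional} &= 2 \\cdot 4\\,\\bigl(u\\,(d + u) + u\\bigr) = 8\\,(64 \\cdot 74 + 64) = 38{,}400 \\\\\n",
    "\\text{bidirectional\\_1} &= 8\\,(32 \\cdot 160 + 32) = 41{,}216 \\\\\n",
    "\\text{Dense} &= 64 \\cdot 10 + 10 = 650\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The three terms add up to the reported total of 80,266."
   ]
  },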
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6 TensorFlow2.0中的性能优化和CuDNN内核\n",
    "\n",
    "在TensorFlow2.0中，内置的LSTM和GRU层已更新，已在有GPU时默认使用CuDNN内核。通过此更改，先前的keras.layers.CuDNNLSTM/CuDNNGRU层已经弃用，可以简易的构建模型而不必担心其运行的硬件。\n",
    "\n",
    "由于CuDNN内核是根据某些假设构建的，因此如果改变内置LSTM或GRU的默认设置，则该层无法使用CuDNN内核，例如:\n",
    "\n",
    "- 将activation从tanh改为其他激活函数。\n",
    "- 将recurrent_activation由sigmoid改为其他激活函数。\n",
    "- 使recurrent_dropout> 0。\n",
    "- 设置unroll为True，将强制LSTM / GRU将内部分解tf.while_loop为展开的for循环。\n",
    "- 设置use_bias为False。\n",
    "- 当输入数据未严格右填充时使用掩蔽（如果掩码对应于严格右填充数据，则仍可以使用CuDNN。这是最常见的情况）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 在可用时使用CuDNN内核\n",
    "我们构建一个简单的LSTM网络，来实现MNIST数字识别。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "input_dim = 28\n",
    "units = 64\n",
    "output_size = 10\n",
    "\n",
    "def build_model(allow_cudnn_kernel=True):\n",
    "    if allow_cudnn_kernel:\n",
    "        # LSTM默认cudnn加速\n",
    "        lstm_layer = tf.keras.layers.LSTM(units, input_shape=(None, input_dim))\n",
    "    else:\n",
    "        # LSTMCell内核没有使用\n",
    "        lstm_layer = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(units),\n",
    "                                        input_shape=(None, input_dim))\n",
    "    model = tf.keras.models.Sequential([\n",
    "        lstm_layer,\n",
    "        tf.keras.layers.BatchNormalization(),\n",
    "        tf.keras.layers.Dense(output_size, activation='softmax')\n",
    "    ])\n",
    "    return model\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载MNIST数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnits = tf.keras.datasets.mnist\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = mnits.load_data()\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "sample, sample_label = x_train[0], y_train[0]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建模型实例并进行编译\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = build_model(allow_cudnn_kernel=True)\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "             optimizer='sgd',\n",
    "             metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/2\n",
      "60000/60000 [==============================] - 10s 169us/sample - loss: 0.1457 - accuracy: 0.9560 - val_loss: 0.1290 - val_accuracy: 0.9566\n",
      "Epoch 2/2\n",
      "60000/60000 [==============================] - 10s 170us/sample - loss: 0.1284 - accuracy: 0.9614 - val_loss: 0.1394 - val_accuracy: 0.9543\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f94d0c8e7b8>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用cudnn的训练\n",
    "model.fit(x_train, y_train,\n",
    "         validation_data=(x_test, y_test),\n",
    "         batch_size=batch_size,\n",
    "         epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在没有CuDNN内核的情况下构建新模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/2\n",
      "60000/60000 [==============================] - 11s 185us/sample - loss: 0.9310 - accuracy: 0.7030 - val_loss: 0.5106 - val_accuracy: 0.8419\n",
      "Epoch 2/2\n",
      "60000/60000 [==============================] - 10s 172us/sample - loss: 0.4209 - accuracy: 0.8715 - val_loss: 0.4863 - val_accuracy: 0.8416\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f94f31478d0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "slow_model = build_model(allow_cudnn_kernel=False)\n",
    "#slow_model.set_weights(model.get_weights())\n",
    "slow_model.compile(loss='sparse_categorical_crossentropy',\n",
    "             optimizer='sgd',\n",
    "             metrics=['accuracy'])\n",
    "slow_model.fit(x_train, y_train,\n",
    "              validation_data=(x_test, y_test),\n",
    "              batch_size=batch_size,\n",
    "              epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用cudnn的模型也可以在cpu环境中进行推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测结果: [5], 正确标签: 5\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADgpJREFUeJzt3X+MVfWZx/HPs1j+kKI4aQRCYSnEYJW4082IjSWrxkzVDQZHrekkJjQapn8wiU02ZA3/VNNgyCrslmiamaZYSFpKE3VB0iw0otLGZuKIWC0srTFsO3IDNTjywx9kmGf/mEMzxbnfe+fec++5zPN+JeT+eM6558kNnznn3O+592vuLgDx/EPRDQAoBuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUZc3cmJlxOSHQYO5u1SxX157fzO40syNm9q6ZPVrPawFoLqv12n4zmybpj5I6JQ1Jel1St7sfSqzDnh9osGbs+ZdJetfd33P3c5J+IWllHa8HoInqCf88SX8Z93goe+7vmFmPmQ2a2WAd2wKQs3o+8Jvo0OJzh/Xu3i+pX+KwH2gl9ez5hyTNH/f4y5KO1dcOgGapJ/yvS7rGzL5iZtMlfVvSrnzaAtBoNR/2u/uImfVK2iNpmqQt7v6H3DoD0FA1D/XVtDHO+YGGa8pFPgAuXYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVfMU3ZJkZkclnZZ0XtKIu3fk0RTyM23atGT9yiuvbOj2e3t7y9Yuv/zy5LpLlixJ1tesWZOsP/XUU2Vr3d3dyXU//fTTZH3Dhg3J+uOPP56st4K6wp+5zd0/yOF1ADQRh/1AUPWG3yXtNbM3zKwnj4YANEe9h/3fcPdjZna1pF+b2f+6+/7xC2R/FPjDALSYuvb87n4suz0h6QVJyyZYpt/dO/gwEGgtNYffzGaY2cwL9yV9U9I7eTUGoLHqOeyfLekFM7vwOj939//JpSsADVdz+N39PUn/lGMvU9aCBQuS9enTpyfrN998c7K+fPnysrVZs2Yl173vvvuS9SINDQ0l65s3b07Wu7q6ytZOnz6dXPett95K1l999dVk/VLAUB8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35m3MrHkba6L29vZkfd++fcl6o79W26pGR0eT9YceeihZP3PmTM3bLpVKyfqHH36YrB85cqTmbTeau1s1y7HnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOfPQVtbW7I+MDCQrC9atCjPdnJVqffh4eFk/bbbbitbO3fuXHLdqNc/1ItxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QVB6z9IZ38uTJZH3t2rXJ+ooVK5L1N998M1mv9BPWKQcPHkzWOzs7k/WzZ88m69dff33Z2iOPPJJcF43Fnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX4z2yJphaQT7r40e65N0g5JCyUdlfSAu6d/6FxT9/v89briiiuS9UrTSff19ZWtPfzww8l1H3zwwWR9+/btyTpaT57f5/+ppDsveu5RSS+5+zWSXsoeA7iEVAy/u++XdPElbCslbc3ub5V0T859AWiwWs/5Z7t7SZKy26vzawlAMzT82n4z65HU0+jtAJicWvf8x81sriRltyfKLeju/e7e4e4dNW4LQAPUGv5dklZl91dJ2plPOwCapWL4zWy7pN9JWmJmQ2b2sKQNkjrN7E+SOrPHAC4hFc/53b27TOn2nHsJ69SpU3Wt/9FHH9W87urVq5P1HTt2JOujo6M1bxvF4go/ICjCDwRF+IGgCD8QFOEHgiL8QFBM0T0FzJgxo2ztxRdfTK57yy23JOt3
3XVXsr53795kHc3HFN0Akgg/EBThB4Ii/EBQhB8IivADQRF+ICjG+ae4xYsXJ+sHDhxI1oeHh5P1l19+OVkfHBwsW3vmmWeS6zbz/+ZUwjg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gurq6kvVnn302WZ85c2bN2163bl2yvm3btmS9VCrVvO2pjHF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBUxXF+M9siaYWkE+6+NHvuMUmrJf01W2ydu/+q4sYY57/kLF26NFnftGlTsn777bXP5N7X15esr1+/Pll///33a972pSzPcf6fSrpzguf/093bs38Vgw+gtVQMv7vvl3SyCb0AaKJ6zvl7zez3ZrbFzK7KrSMATVFr+H8kabGkdkklSRvLLWhmPWY2aGblf8wNQNPVFH53P+7u5919VNKPJS1LLNvv7h3u3lFrkwDyV1P4zWzuuIddkt7Jpx0AzXJZpQXMbLukWyV9ycyGJH1f0q1m1i7JJR2V9N0G9gigAfg+P+oya9asZP3uu+8uW6v0WwFm6eHqffv2JeudnZ3J+lTF9/kBJBF+ICjCDwRF+IGgCD8QFOEHgmKoD4X57LPPkvXLLktfhjIyMpKs33HHHWVrr7zySnLdSxlDfQCSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrf50dsN9xwQ7J+//33J+s33nhj2VqlcfxKDh06lKzv37+/rtef6tjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPcUuWLEnWe3t7k/V77703WZ8zZ86ke6rW+fPnk/VSqZSsj46O5tnOlMOeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2bzJW2TNEfSqKR+d/+hmbVJ2iFpoaSjkh5w9w8b12pclcbSu7u7y9YqjeMvXLiwlpZyMTg4mKyvX78+Wd+1a1ee7YRTzZ5/RNK/uftXJX1d0hozu07So5JecvdrJL2UPQZwiagYfncvufuB7P5pSYclzZO0UtLWbLGtku5pVJMA8jepc34zWyjpa5IGJM1295I09gdC0tV5Nwegcaq+tt/MvijpOUnfc/dTZlVNByYz65HUU1t7ABqlqj2/mX1BY8H/mbs/nz193MzmZvW5kk5MtK6797t7h7t35NEwgHxUDL+N7eJ/Iumwu28aV9olaVV2f5Wknfm3B6BRKk7RbWbLJf1G0tsaG+qTpHUaO+//paQFkv4s6VvufrLCa4Wconv27NnJ+nXXXZesP/3008n6tddeO+me8jIwMJCsP/nkk2VrO3em9xd8Jbc21U7RXfGc391/K6nci90+maYAtA6u8AOCIvxAUIQfCIrwA0ERfiAowg8ExU93V6mtra1sra+vL7lue3t7sr5o0aKaesrDa6+9lqxv3LgxWd+zZ0+y/sknn0y6JzQHe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCCrMOP9NN92UrK9duzZZX7ZsWdnavHnzauopLx9//HHZ2ubNm5PrPvHEE8n62bNna+oJrY89PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EFWacv6urq656PQ4dOpSs7969O1kfGRlJ1lPfuR8eHk6ui7jY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUObu6QXM5kvaJmmOpFFJ/e7+QzN7TNJqSX/NFl3n7r+q8FrpjQGom7tbNctVE/65kua6+wEzmynpDUn3SHpA0hl3f6rapgg/0HjVhr/iFX7uXpJUyu6fNrPDkor96RoAdZvUOb+ZLZT0NUkD2VO9ZvZ7M9tiZleVWafHzAbNbLCuTgHkquJh/98WNPuipFclrXf3581stqQPJLmkH2js1OChCq/BYT/QYLmd80uSmX1B0m5Je9x90wT1hZJ2u/vSCq9D+IEGqzb8FQ/7zcwk/UTS4fHBzz4IvKBL0juTbRJAcar5tH+5pN9IeltjQ32StE5St6R2jR32H5X03ezDwdRrsecHGizXw/68
EH6g8XI77AcwNRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCavYU3R9I+r9xj7+UPdeKWrW3Vu1Lorda5dnbP1a7YFO/z/+5jZsNuntHYQ0ktGpvrdqXRG+1Kqo3DvuBoAg/EFTR4e8vePsprdpbq/Yl0VutCumt0HN+AMUpes8PoCCFhN/M7jSzI2b2rpk9WkQP5ZjZUTN728wOFj3FWDYN2gkze2fcc21m9msz+1N2O+E0aQX19piZvZ+9dwfN7F8L6m2+mb1sZofN7A9m9kj2fKHvXaKvQt63ph/2m9k0SX+U1ClpSNLrkrrd/VBTGynDzI5K6nD3wseEzexfJJ2RtO3CbEhm9h+STrr7huwP51Xu/u8t0ttjmuTMzQ3qrdzM0t9Rge9dnjNe56GIPf8ySe+6+3vufk7SLyStLKCPlufu+yWdvOjplZK2Zve3auw/T9OV6a0luHvJ3Q9k909LujCzdKHvXaKvQhQR/nmS/jLu8ZBaa8pvl7TXzN4ws56im5nA7AszI2W3Vxfcz8UqztzcTBfNLN0y710tM17nrYjwTzSbSCsNOXzD3f9Z0l2S1mSHt6jOjyQt1tg0biVJG4tsJptZ+jlJ33P3U0X2Mt4EfRXyvhUR/iFJ88c9/rKkYwX0MSF3P5bdnpD0gsZOU1rJ8QuTpGa3Jwru52/c/bi7n3f3UUk/VoHvXTaz9HOSfubuz2dPF/7eTdRXUe9bEeF/XdI1ZvYVM5su6duSdhXQx+eY2YzsgxiZ2QxJ31TrzT68S9Kq7P4qSTsL7OXvtMrMzeVmllbB712rzXhdyEU+2VDGf0maJmmLu69vehMTMLNFGtvbS2PfePx5kb2Z2XZJt2rsW1/HJX1f0n9L+qWkBZL+LOlb7t70D97K9HarJjlzc4N6Kzez9IAKfO/ynPE6l364wg+IiSv8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E9f/Ex0YKZYOZcwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "with tf.device('CPU:0'):\n",
    "    cpu_model = build_model(allow_cudnn_kernel=True)\n",
    "    cpu_model.set_weights(model.get_weights())\n",
    "    result = tf.argmax(cpu_model.predict_on_batch(tf.expand_dims(sample, 0)), axis=1)\n",
    "    print('预测结果: %s, 正确标签: %s' % (result.numpy(), sample_label))\n",
    "    plt.imshow(sample, cmap=plt.get_cmap('gray'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7.具有列表/字典输入或嵌套输入的RNN\n",
    "嵌套结构可以在单个时间步内包含更多信息。例如，一个视频帧可以同时具有音频和视频输入。比如以下的数据格式：\n",
    "\n",
    "[batch, timestep, {\"video\": [height, width, channel], \"audio\": [frequency]}]\n",
    "\n",
    "另一个例子，如笔迹数据。可以有当前的位置坐标x，y和压力信息。格式如下：\n",
    "\n",
    "[batch, timestep, {\"location\": [x, y], \"pressure\": [force]}]\n",
    "\n",
    "### 7.1 定义一个支持嵌套输入/输出的自定义单元格\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "NestedInput = collections.namedtuple('NestedInput', ['feature1', 'feature2'])\n",
    "NestedState = collections.namedtuple('NestedState', ['state1', 'state2'])\n",
    "\n",
    "class NestedCell(tf.keras.layers.Layer):\n",
    "    # 初始化，获取相关参数\n",
    "    def __init__(self,unit1, unit2, unit3, **kwargs):\n",
    "        self.unit1 = unit1\n",
    "        self.unit2 = unit2\n",
    "        self.unit3 = unit3\n",
    "        self.state_size = NestedState(state1=unit1,\n",
    "                                     state2=tf.TensorShape([unit2, unit3]))\n",
    "        \n",
    "        self.output_size = (unit1, tf.TensorShape([unit2, unit3]))\n",
    "        super(NestedCell, self).__init__(**kwargs)\n",
    "    # 构建权重、网络\n",
    "    def build(self, input_shapes):\n",
    "        # input_shape包含2个特征项 [(batch, i1), (batch, i2, i3)]\n",
    "        input1 = input_shapes.feature1[1]\n",
    "        input2, input3 = input_shapes.feature2[1:]\n",
    "        \n",
    "        self.kernel_1 = self.add_weight(\n",
    "            shape=(input1, self.unit1), initializer='uniform', name='kernel_1'\n",
    "        )\n",
    "        self.kernel_2_3 = self.add_weight(\n",
    "            shape=(input2, input3, self.unit2, self.unit3),\n",
    "            initializer='uniform',\n",
    "            name='kernel_2_3'\n",
    "        )\n",
    "    # 前向连接网络\n",
    "    def call(self, inputs, states):\n",
    "        # inputs: [(batch, input_1), (batch, input_2, input_3)]\n",
    "        # state: [(batch, unit_1), (batch, unit_2, unit_3)]\n",
    "        input1, input2 = tf.nest.flatten(inputs)\n",
    "        s1, s2 = states\n",
    "        \n",
    "        output_1 = tf.matmul(input1, self.kernel_1)\n",
    "        output_2_3 = tf.einsum('bij,ijkl->bkl', input2, self.kernel_2_3)\n",
    "        \n",
    "        state_1 = s1 + output_1\n",
    "        state_2_3 = s2 + output_2_3\n",
    "        \n",
    "        output = [output_1, output_2_3]\n",
    "        new_states = NestedState(state1=state_1, state2=state_2_3)\n",
    "        return output, new_states\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 使用嵌套的输入输出构建RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "unit_1 = 10\n",
    "unit_2 = 20\n",
    "unit_3 = 30\n",
    "\n",
    "input_1 = 32\n",
    "input_2 = 64\n",
    "input_3 = 32\n",
    "batch_size = 64\n",
    "num_batch = 100\n",
    "timestep = 50\n",
    "cell = NestedCell(unit_1, unit_2, unit_3)\n",
    "rnn = tf.keras.layers.RNN(cell)\n",
    "input1 = tf.keras.Input((None, input_1))\n",
    "input2 = tf.keras.Input((None, input_2, input_3))\n",
    "outputs = rnn(NestedInput(feature1=input1, feature2=input2))\n",
    "model = tf.keras.models.Model([input1, input2], outputs)\n",
    "model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造数据训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 6400 samples\n",
      "6400/6400 [==============================] - 191s 30ms/sample - loss: 0.3808 - rnn_7_loss: 0.1197 - rnn_7_1_loss: 0.2611 - rnn_7_accuracy: 0.1003 - rnn_7_1_accuracy: 0.0333\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f94c9ee4ba8>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_1_data = np.random.random((batch_size * num_batch, timestep, input_1))\n",
    "input_2_data = np.random.random((batch_size * num_batch, timestep, input_2, input_3))\n",
    "target_1_data = np.random.random((batch_size * num_batch, unit_1))\n",
    "target_2_data = np.random.random((batch_size * num_batch, unit_2, unit_3))\n",
    "input_data = [input_1_data, input_2_data]\n",
    "target_data = [target_1_data, target_2_data]\n",
    "\n",
    "model.fit(input_data, target_data, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
